Matplotlib and Seaborn
Reading time: ~40 minutes | Level: Intermediate-Advanced
The Debugging Scenario
Your model just shipped to production. Precision on the holdout set was 0.91 -- excellent. Two weeks later, the business team reports the model is making obviously wrong decisions on a new customer cohort. You pull up the metrics dashboard: accuracy is still 0.88. Nothing looks wrong.
The problem was visible all along -- but only in a plot you never made.
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix
# 1000 labelled examples: 850 negatives, 150 positives
y_true = np.array([0]*850 + [1]*150)
y_pred = np.array([0]*830 + [1]*20 + [0]*145 + [1]*5)
# 830 true negatives, 20 false positives, 145 false negatives, 5 true positives
# What the summary metric showed you
print(f"Accuracy: {accuracy_score(y_true, y_pred):.3f}")  # 0.835 -- looks fine at a glance
# What a confusion matrix would have shown you immediately
# Recall on the positive class: 5 / (5 + 145) = 0.033 -- near zero
cm = confusion_matrix(y_true, y_pred)
fig, ax = plt.subplots(figsize=(5, 4))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=["Pred 0", "Pred 1"],
            yticklabels=["True 0", "True 1"], ax=ax)
ax.set_title("Confusion Matrix -- Class Imbalance Disaster")
fig.tight_layout()
fig.savefig("confusion_matrix.png", dpi=150)
plt.close(fig)
The confusion matrix would have screamed: your model learned to predict the majority class and ignore the minority entirely. This lesson builds the visual toolkit that catches these failures before they reach production.
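The arithmetic behind that verdict fits in a few lines of NumPy. The sketch below recomputes the headline numbers from the four cells quoted above. The helper name `binary_report` is illustrative, not a library function. It also adds a "majority baseline": the accuracy of a model that always predicts class 0.

```python
import numpy as np

def binary_report(cm: np.ndarray) -> dict[str, float]:
    """Summary metrics from a 2x2 confusion matrix laid out as [[TN, FP], [FN, TP]]
    (rows = true class, columns = predicted class), the layout
    sklearn.metrics.confusion_matrix uses for labels 0/1."""
    tn, fp, fn, tp = cm.ravel()
    return {
        "accuracy": (tn + tp) / cm.sum(),
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "recall": tp / (tp + fn) if tp + fn else 0.0,
        # A model that always predicts class 0 gets every true negative right
        "majority_baseline": (tn + fp) / cm.sum(),
    }

cm = np.array([[830, 20],
               [145, 5]])
for name, value in binary_report(cm).items():
    print(f"{name:>18}: {value:.3f}")
```

The majority baseline (0.850) is higher than the model's accuracy (0.835). A classifier that always answers "negative" would score better, and accuracy alone never reveals this.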
Why This Matters
ML engineering is not just about fitting models -- it is about diagnosing them. The gap between a junior practitioner and a senior ML engineer is often the set of diagnostic plots they instinctively reach for.
You will use these skills constantly:
- Training/validation curves to diagnose overfitting and underfitting in real time
- Confusion matrices and ROC/PR curves to understand classifier behaviour beyond accuracy
- Feature importance plots to communicate what the model actually learned
- Distribution and correlation plots during EDA to catch data quality issues before they infect the model
- Embedding visualisations (t-SNE, UMAP) to understand what a neural network has learned about structure in your data
Seaborn makes statistical plots easy; Matplotlib gives you the control to make them publishable and embeddable in reports, papers, and dashboards.
1. Matplotlib Figure Anatomy
Before you can control a figure, you need to understand its object hierarchy.
Everything in a Matplotlib figure is an Artist -- a Python object that knows how to draw itself onto a renderer. The hierarchy is:
- Figure: the top-level container, represents the entire canvas
- Axes: one plotting area inside the Figure; a Figure can have many Axes
- Axis (singular): the X or Y axis on an Axes object -- holds ticks, tick labels and the axis label (spines belong to the Axes, via ax.spines)
- Artists: lines, patches, text, images, collections -- everything visible
import matplotlib.pyplot as plt
import numpy as np
# Always prefer the explicit object-oriented API in production code.
# plt.plot() uses an implicit "current axes", which breaks as soon as
# you have multiple subplots or build figures inside functions.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))
# axes is a 2D numpy array of Axes objects
ax = axes[0, 0]  # top-left subplot
ax.set_title("Top Left")
ax.set_xlabel("X label")
ax.set_ylabel("Y label")
# Access the underlying Artist objects directly
print(type(ax.xaxis))  # <class 'matplotlib.axis.XAxis'>
print(type(ax.title))  # <class 'matplotlib.text.Text'>
# Modify tick label font size -- two equivalent approaches
ax.tick_params(axis="both", labelsize=9)
for tick in ax.get_xticklabels():
    tick.set_fontsize(9)  # also valid; same underlying Artist
fig.tight_layout()  # auto-adjust spacing between subplots
fig.savefig("anatomy.png", dpi=150, bbox_inches="tight")
plt.close(fig)  # always close to free memory in scripts / long loops
The key rule: use the object-oriented API (fig, ax = plt.subplots()). The stateful plt.* API depends on hidden global state and causes subtle bugs when figures are built inside functions or loops.
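The sketch below illustrates the hidden-state problem, assuming only Matplotlib with the non-interactive Agg backend. The helper `add_reference_line` is hypothetical. Because it calls `plt.plot()`, it draws on whichever Axes pyplot considers current. That is not necessarily the Axes the caller had in mind.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; no window is opened
import matplotlib.pyplot as plt

def add_reference_line(y: float) -> None:
    """Hypothetical helper written against the stateful API."""
    plt.plot([0, 1], [y, y], linestyle="--")

fig, (ax_left, ax_right) = plt.subplots(1, 2)
ax_left.plot([0, 1], [0, 1])

# Meant for the left panel, but plt.subplots() left the last-created Axes current
add_reference_line(0.5)
print(len(ax_left.lines), len(ax_right.lines))  # 1 1 -> the line landed on the right

# Object-oriented fix: pass the target Axes explicitly
def add_reference_line_oo(ax: plt.Axes, y: float) -> None:
    ax.axhline(y, linestyle="--")

add_reference_line_oo(ax_left, 0.5)
print(len(ax_left.lines), len(ax_right.lines))  # 2 1
plt.close(fig)
```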
2. Reproducible Style with rcParams
Reproducible figures require reproducible styles. Hard-coding font sizes on every plot leads to inconsistency across notebooks and reports. The right approach: one style configuration that all figures inherit.
import matplotlib as mpl
import matplotlib.pyplot as plt
# Set once -- at the top of your script, in a conftest.py, or a shared style module
mpl.rcParams.update({
    "figure.dpi": 150,
    "figure.figsize": (8, 5),
    "axes.spines.top": False,    # remove top spine for cleaner look
    "axes.spines.right": False,  # remove right spine
    "axes.grid": True,
    "grid.alpha": 0.3,
    "grid.linestyle": "--",
    "font.size": 11,
    "axes.titlesize": 13,
    "axes.labelsize": 11,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
    "legend.frameon": False,
    "lines.linewidth": 2.0,
    "savefig.bbox": "tight",
    "savefig.dpi": 300,          # publication quality for saved files
})
# OR use a named style sheet (applied after the update above, it overrides overlapping keys)
plt.style.use("seaborn-v0_8-whitegrid")  # mpl >= 3.6 renamed the "seaborn-*" styles to "seaborn-v0_8-*"
# Temporarily override for one block without affecting global state
with plt.style.context("dark_background"):
    fig, ax = plt.subplots()
    ax.plot([1, 2, 3], [1, 4, 9])
    fig.savefig("dark_plot.png")
    plt.close(fig)
For multi-notebook projects, centralise your style in a shared module:
# ml_plot_style.py -- import this at the top of every analysis notebook
import matplotlib as mpl
STYLE = {
    "figure.dpi": 150,
    "figure.figsize": (8, 5),
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.grid": True,
    "grid.alpha": 0.25,
    "font.family": "DejaVu Sans",
    "font.size": 11,
    "axes.titlesize": 13,
    "legend.frameon": False,
    "savefig.dpi": 300,
    "savefig.bbox": "tight",
}
def apply_style() -> None:
    mpl.rcParams.update(STYLE)
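Some notebooks need the shared style for only a few figures. For that case, the same dictionary can be applied in a scoped way. This is a minimal sketch: it repeats a trimmed copy of the `STYLE` idea so it runs on its own, and `ml_style` is a hypothetical name. It relies on `matplotlib.rc_context`, which restores the previous rcParams when the block exits.

```python
import matplotlib as mpl
mpl.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
from contextlib import AbstractContextManager

# Trimmed copy of the shared STYLE dict so this block runs on its own
STYLE = {
    "axes.spines.top": False,
    "axes.spines.right": False,
    "font.size": 11,
    "savefig.dpi": 300,
}

def ml_style() -> AbstractContextManager:
    """Scoped alternative to apply_style(): STYLE applies only inside the with-block.

    mpl.rc_context restores the rcParams it saved on entry when the block
    exits, even if an exception is raised inside it.
    """
    return mpl.rc_context(STYLE)

with ml_style():
    fig, ax = plt.subplots()  # this Figure picks up STYLE
    print(mpl.rcParams["axes.spines.top"])  # False
    plt.close(fig)
print(mpl.rcParams["axes.spines.top"])  # back to the previous value (True by default)
```

The trade-off: `apply_style()` is convenient for a notebook that should look uniform throughout. The context-manager form avoids the global-state leak described in Mistake 1 below.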
3. Training and Validation Curves
The training curve is the most important diagnostic plot in supervised learning. It tells you immediately whether your model is overfitting, underfitting, or properly converging.
import matplotlib.pyplot as plt
import numpy as np
def plot_training_curves(
    train_losses: list[float],
    val_losses: list[float],
    train_metric: list[float] | None = None,
    val_metric: list[float] | None = None,
    metric_name: str = "Accuracy",
    save_path: str | None = None,
) -> plt.Figure:
    """
    Two-panel training diagnostic: loss (always shown) + optional metric.
    Returns the Figure so callers can embed it in reports or further annotate.
    """
    show_metric = train_metric is not None and val_metric is not None
    n_panels = 2 if show_metric else 1
    # squeeze=False always returns a 2D array, so indexing works for one or two panels
    fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 4), squeeze=False)
    axes = axes[0]
    epochs = range(1, len(train_losses) + 1)
    # --- Loss panel ---
    ax = axes[0]
    ax.plot(epochs, train_losses, label="Train loss", color="#2196F3")
    ax.plot(epochs, val_losses, label="Val loss", color="#F44336", linestyle="--")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Loss Curves")
    ax.legend()
    # Mark the best validation epoch with a vertical line + annotation
    best_epoch = int(np.argmin(val_losses)) + 1
    best_val = min(val_losses)
    ax.axvline(x=best_epoch, color="gray", linestyle=":", alpha=0.7)
    ax.annotate(
        f"Best val\nepoch {best_epoch}\n{best_val:.4f}",
        xy=(best_epoch, best_val),
        xytext=(best_epoch + len(val_losses) * 0.05, best_val * 1.1),
        arrowprops=dict(arrowstyle="->", color="gray"),
        fontsize=8, color="gray",
    )
    # --- Optional metric panel (needs both train and val values) ---
    if show_metric:
        ax2 = axes[1]
        ax2.plot(epochs, train_metric, label=f"Train {metric_name}", color="#2196F3")
        ax2.plot(epochs, val_metric, label=f"Val {metric_name}", color="#F44336", linestyle="--")
        ax2.set_xlabel("Epoch")
        ax2.set_ylabel(metric_name)
        ax2.set_title(f"{metric_name} Curves")
        ax2.legend()
    fig.suptitle("Training Diagnostics", fontsize=14, y=1.01)
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path)
    return fig
# Simulate a typical overfitting scenario
np.random.seed(42)
n_epochs = 60
train_loss = np.exp(-np.linspace(0, 3.5, n_epochs)) + np.random.normal(0, 0.008, n_epochs)
val_loss = (np.exp(-np.linspace(0, 2.2, n_epochs))
            + np.linspace(0, 0.35, n_epochs)
            + np.random.normal(0, 0.015, n_epochs))
fig = plot_training_curves(train_loss, val_loss, save_path="training_curves.png")
plt.close(fig)
Reading the curve: when training loss keeps falling but validation loss starts rising, the model is memorising the training set. The best checkpoint is at the minimum validation loss, not the final epoch. Early stopping monitors val_loss and halts training once it has failed to improve for a set number of epochs (the patience).
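The early-stopping rule fits in a few lines of plain Python. This is a minimal sketch of the usual patience-based rule. Framework callbacks such as Keras' `EarlyStopping` add options like `restore_best_weights`. The names `early_stopping_epoch`, `patience` and `min_delta` here are illustrative. The function replays a recorded validation history rather than hooking into a live training loop.

```python
from typing import Sequence

def early_stopping_epoch(
    val_losses: Sequence[float],
    patience: int = 5,
    min_delta: float = 0.0,
) -> tuple[int, int]:
    """Replay a validation-loss history and report when patience-based early
    stopping would have halted.

    Returns (stop_epoch, best_epoch), both 1-indexed. stop_epoch equals
    len(val_losses) if the patience budget never ran out.
    """
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss - min_delta:   # counts as an improvement
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch
    return len(val_losses), best_epoch

# Validation loss bottoms out at epoch 4, then rises
history = [0.90, 0.70, 0.60, 0.55, 0.56, 0.58, 0.61, 0.65, 0.70]
print(early_stopping_epoch(history, patience=3))  # (7, 4)
```

Training halts at epoch 7, but the checkpoint worth keeping is epoch 4, the minimum of the validation curve.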
4. Confusion Matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import confusion_matrix
def plot_confusion_matrix(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    class_names: list[str],
    normalize: bool = True,
    save_path: str | None = None,
) -> plt.Figure:
    """
    Confusion matrix with optional per-class normalisation.
    Args:
        normalize: if True, show proportions (row-normalised = recall per class).
            Raw counts favour large classes and hide minority class failures.
    """
    cm = confusion_matrix(y_true, y_pred)
    if normalize:
        # Divide each row by its sum: shows what fraction of each true class
        # was correctly or incorrectly classified
        cm_display = cm.astype(float) / cm.sum(axis=1, keepdims=True)
        fmt = ".2f"
        cbar_label = "Recall (proportion)"
    else:
        cm_display = cm
        fmt = "d"
        cbar_label = "Count"
    n = len(class_names)
    fig, ax = plt.subplots(figsize=(max(5, n * 1.2), max(4, n)))
    sns.heatmap(
        cm_display,
        annot=True,
        fmt=fmt,
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
        linewidths=0.5,
        linecolor="white",
        ax=ax,
        vmin=0,
        vmax=1 if normalize else None,
        cbar_kws={"label": cbar_label},
    )
    ax.set_xlabel("Predicted label", fontsize=11)
    ax.set_ylabel("True label", fontsize=11)
    ax.set_title(
        "Confusion Matrix (normalised)" if normalize else "Confusion Matrix (counts)",
        fontsize=12,
    )
    if n > 4:
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path)
    return fig
# 4-class problem where the model confuses minority classes 2 and 3
rng = np.random.default_rng(0)
y_true = rng.choice([0, 1, 2, 3], size=600, p=[0.5, 0.3, 0.1, 0.1])
y_pred = y_true.copy()
mask2 = y_true == 2
mask3 = y_true == 3
# Class 2: model mostly predicts class 3
y_pred[mask2] = rng.choice([2, 3], size=mask2.sum(), p=[0.35, 0.65])
# Class 3: model has reasonable recall
y_pred[mask3] = rng.choice([2, 3], size=mask3.sum(), p=[0.25, 0.75])
fig = plot_confusion_matrix(
    y_true, y_pred,
    class_names=["Benign", "Suspicious", "Malignant-A", "Malignant-B"],
    normalize=True,
    save_path="confusion_matrix_normalised.png",
)
plt.close(fig)
Why normalise by row? Raw counts make majority classes look well-classified simply because they have more samples. Normalising by the true class count converts each cell to per-class recall, immediately exposing minority class failure.
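The following NumPy check makes this concrete with made-up counts for a 3-class problem (illustrative numbers, not the simulation above). Row normalisation puts per-class recall on the diagonal. Normalising by column instead gives per-class precision, which can be worth plotting as a second view.

```python
import numpy as np

# Rows = true class, columns = predicted class (sklearn's convention)
cm = np.array([
    [900,  40,  10],   # majority class
    [ 30, 150,  20],
    [ 25,  20,   5],   # minority class: only 5 of 50 correct
])

recall = np.diag(cm) / cm.sum(axis=1)      # divide by row totals (true counts)
precision = np.diag(cm) / cm.sum(axis=0)   # divide by column totals (predicted counts)
row_normalised = cm / cm.sum(axis=1, keepdims=True)

print("raw diagonal:  ", np.diag(cm))
print("recall:        ", recall.round(3))
print("precision:     ", precision.round(3))
print("rows sum to 1: ", np.allclose(row_normalised.sum(axis=1), 1.0))
```

In raw counts, the minority class's 5 correct predictions are simply a small number next to 900. After row normalisation, its cell reads 0.10, which is impossible to miss.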
5. ROC and Precision-Recall Curves
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
def plot_roc_pr(
    y_true: np.ndarray,
    y_scores: dict[str, np.ndarray],
    save_path: str | None = None,
) -> plt.Figure:
    """
    Side-by-side ROC and Precision-Recall curves for multiple models.
    Args:
        y_scores: dict mapping model name to predicted probability array (shape [n,])
    """
    fig, (ax_roc, ax_pr) = plt.subplots(1, 2, figsize=(13, 5))
    colors = ["#2196F3", "#F44336", "#4CAF50", "#FF9800", "#9C27B0"]
    for (name, scores), color in zip(y_scores.items(), colors):
        # ROC
        fpr, tpr, _ = roc_curve(y_true, scores)
        roc_auc = auc(fpr, tpr)
        ax_roc.plot(fpr, tpr, label=f"{name} (AUC={roc_auc:.3f})", color=color)
        # Precision-Recall
        precision, recall, _ = precision_recall_curve(y_true, scores)
        ap = average_precision_score(y_true, scores)
        ax_pr.plot(recall, precision, label=f"{name} (AP={ap:.3f})", color=color)
    # ROC cosmetics
    ax_roc.plot([0, 1], [0, 1], "k--", alpha=0.4, linewidth=1, label="Random")
    ax_roc.set_xlabel("False Positive Rate")
    ax_roc.set_ylabel("True Positive Rate")
    ax_roc.set_title("ROC Curve")
    ax_roc.legend(loc="lower right", fontsize=9)
    # PR baseline: a random classifier's precision equals the positive-class prevalence
    prevalence = y_true.mean()
    ax_pr.axhline(y=prevalence, color="k", linestyle="--", alpha=0.4,
                  linewidth=1, label=f"Random (AP={prevalence:.3f})")
    ax_pr.set_xlabel("Recall")
    ax_pr.set_ylabel("Precision")
    ax_pr.set_title("Precision-Recall Curve")
    ax_pr.legend(loc="upper right", fontsize=9)
    ax_pr.set_xlim([0, 1])
    ax_pr.set_ylim([0, 1.05])
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path)
    return fig
# Simulate three models on an imbalanced dataset (10% positive class)
rng = np.random.default_rng(42)
n = 1000
y_true = rng.binomial(1, 0.10, size=n)
# Strong model: positives score 0.7-1.0, negatives 0-0.3 (perfectly separable here)
scores_strong = np.clip(y_true * 0.7 + rng.uniform(0, 0.3, n), 0, 1)
# Weak model: mostly noise
scores_weak = np.clip(y_true * 0.3 + rng.uniform(0, 0.7, n), 0, 1)
# Medium model: intermediate separation (none of these scores are calibrated probabilities)
scores_medium = np.clip(y_true * 0.55 + rng.normal(0.2, 0.15, n), 0, 1)
fig = plot_roc_pr(
    y_true,
    {"Strong": scores_strong, "Weak": scores_weak, "Medium": scores_medium},
    save_path="roc_pr_curves.png",
)
plt.close(fig)
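ROC vs PR: on imbalanced data, ROC-AUC can look impressive even when the model is close to useless for the minority class. The false positive rate divides false positives by the large number of negatives, so hundreds of false alarms barely move it. Precision divides true positives by all predicted positives, so the same false alarms crater it. PR-AUC (average precision) therefore exposes positive-class failures that ROC-AUC hides. On imbalanced datasets, always report PR-AUC alongside ROC-AUC.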
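The asymmetry is easiest to see at a single threshold. The counts below are hypothetical: 10,000 negatives, 100 positives, and a threshold that catches 90 positives along with 500 negatives.

```python
# Hypothetical operating point on a 1%-prevalence dataset
n_neg, n_pos = 10_000, 100
tp, fp = 90, 500
fn = n_pos - tp

tpr = tp / n_pos            # TPR == recall: ROC y-axis, PR x-axis
fpr = fp / n_neg            # ROC x-axis: false positives diluted by the many negatives
precision = tp / (tp + fp)  # PR y-axis: false positives set against the few true positives

print(f"TPR/recall = {tpr:.3f}")        # 0.900
print(f"FPR        = {fpr:.3f}")        # 0.050
print(f"precision  = {precision:.3f}")  # 0.153
```

On the ROC plot, this point sits at (0.05, 0.90), close to the ideal top-left corner. On the PR plot, it sits at recall 0.90 with precision 0.15, which means roughly 85% of flagged cases are false alarms.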
6. Feature Importance Plots
import matplotlib.pyplot as plt
import numpy as np
def plot_feature_importance(
    importances: np.ndarray,
    feature_names: list[str],
    top_n: int = 20,
    errors: np.ndarray | None = None,
    title: str = "Feature Importance",
    save_path: str | None = None,
) -> plt.Figure:
    """
    Horizontal bar chart sorted by importance, with optional error bars.
    Horizontal layout handles long feature names far better than vertical bars.
    Error bars from permutation importance communicate uncertainty.
    """
    # Sort descending, take top_n
    idx = np.argsort(importances)[::-1][:top_n]
    sorted_imp = importances[idx]
    sorted_names = [feature_names[i] for i in idx]
    sorted_err = errors[idx] if errors is not None else None
    fig, ax = plt.subplots(figsize=(8, max(4, top_n * 0.38)))
    # Reverse order so the most important feature appears at the top
    y_pos = np.arange(len(sorted_imp))
    ax.barh(
        y_pos,
        sorted_imp[::-1],
        xerr=sorted_err[::-1] if sorted_err is not None else None,
        color="#2196F3", alpha=0.85,
        error_kw=dict(ecolor="#78909C", capsize=3, linewidth=1),
    )
    ax.set_yticks(y_pos)
    ax.set_yticklabels(sorted_names[::-1], fontsize=9)
    ax.set_xlabel("Importance score")
    ax.set_title(title)
    # Annotate the top feature value (it sits at the highest y position)
    ax.text(
        sorted_imp[0] * 0.97, y_pos[-1],
        f"{sorted_imp[0]:.4f}",
        va="center", ha="right", fontsize=8, color="white", fontweight="bold",
    )
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path)
    return fig
# Simulate Random Forest importances: a few dominant, long tail near zero
rng = np.random.default_rng(7)
n_features = 35
names = [f"feature_{i:02d}" for i in range(n_features)]
raw = rng.exponential(scale=0.04, size=n_features)
raw[:3] = [0.28, 0.17, 0.11]  # three dominant features
importances = raw / raw.sum()
errors = rng.uniform(0.001, 0.018, size=n_features)
fig = plot_feature_importance(
    importances, names,
    top_n=15,
    errors=errors,
    title="Random Forest Feature Importances (top 15)",
    save_path="feature_importance.png",
)
plt.close(fig)
Note on importance types: tree-based impurity importance (the feature_importances_ attribute of sklearn's tree models) is biased toward high-cardinality features. Permutation importance (sklearn.inspection.permutation_importance) is more reliable, so plot both and compare.
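Here is a from-scratch NumPy sketch of permutation importance that shows the idea. It is not sklearn's implementation, which adds scoring options, sample weights and parallelism. The sketch fits ordinary least squares on synthetic data where only the first feature matters, then measures how much R² drops when each column is shuffled. The function names are illustrative. For brevity it scores on the training data; in practice you would use a held-out set.

```python
import numpy as np

def r2(y: np.ndarray, y_hat: np.ndarray) -> float:
    return 1.0 - np.sum((y - y_hat) ** 2) / np.sum((y - y.mean()) ** 2)

def permutation_importance_sketch(predict, X, y, n_repeats=10, seed=0):
    """Mean and std of the drop in R^2 when each column of X is shuffled."""
    rng = np.random.default_rng(seed)
    baseline = r2(y, predict(X))
    drops = np.zeros((X.shape[1], n_repeats))
    for j in range(X.shape[1]):
        for r in range(n_repeats):
            X_perm = X.copy()
            X_perm[:, j] = rng.permutation(X_perm[:, j])  # break the feature-target link
            drops[j, r] = baseline - r2(y, predict(X_perm))
    return drops.mean(axis=1), drops.std(axis=1)

# Synthetic data: y depends on feature 0 only; feature 1 is pure noise
rng = np.random.default_rng(1)
X = rng.normal(size=(500, 2))
y = 3.0 * X[:, 0] + rng.normal(scale=0.5, size=500)

# Ordinary least squares with an intercept column
A = np.column_stack([np.ones(len(X)), X])
coef, *_ = np.linalg.lstsq(A, y, rcond=None)

def predict(X_new: np.ndarray) -> np.ndarray:
    return np.column_stack([np.ones(len(X_new)), X_new]) @ coef

mean_drop, std_drop = permutation_importance_sketch(predict, X, y)
print(mean_drop.round(3), std_drop.round(3))
```

The `(mean_drop, std_drop)` pair has the same shape as the `importances` and `errors` arguments of `plot_feature_importance` above.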
7. Distribution Analysis for EDA
Understanding your data distribution before modelling prevents an entire class of bugs: outliers inflating regression targets, skewed features hurting gradient descent, bimodal distributions that signal hidden subpopulations.
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy import stats
def plot_feature_distributions(
    data: np.ndarray,
    feature_names: list[str],
    ncols: int = 3,
    save_path: str | None = None,
) -> plt.Figure:
    """
    Grid of histograms + KDE for each feature.
    Annotates skewness and excess kurtosis.
    Yellow background flags heavily skewed features that likely need a log transform.
    """
    n = len(feature_names)
    nrows = (n + ncols - 1) // ncols
    # squeeze=False keeps axes 2D even for a single row or a 1x1 grid
    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 3.5 * nrows), squeeze=False)
    axes = axes.flatten()
    for i, name in enumerate(feature_names):
        ax = axes[i]
        col = data[:, i]
        skew = stats.skew(col)
        kurt = stats.kurtosis(col)  # excess kurtosis (Fisher, zero = normal)
        sns.histplot(col, kde=True, ax=ax, color="#5C6BC0", bins=30, alpha=0.6)
        ax.set_title(f"{name}\nskew={skew:.2f} kurt={kurt:.2f}", fontsize=9)
        ax.set_xlabel("")
        ax.set_ylabel("Count")
        # Warn: |skew| > 1 means the feature likely needs a log or Box-Cox transform
        if abs(skew) > 1:
            ax.set_facecolor("#FFF8E1")  # light yellow tint
    for j in range(n, len(axes)):
        axes[j].set_visible(False)  # hide unused grid cells
    fig.suptitle("Feature Distribution Analysis", fontsize=13, y=1.01)
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path)
    return fig
# Mix of distribution families to illustrate the diagnostics
rng = np.random.default_rng(0)
n_samples = 600
data = np.column_stack([
    rng.normal(0, 1, n_samples),            # normal
    rng.lognormal(0, 1, n_samples),         # right-skewed
    np.concatenate([rng.normal(-2, 0.5, 300),
                    rng.normal(2, 0.5, 300)]),  # bimodal
    rng.exponential(2, n_samples),          # exponential
    rng.uniform(-3, 3, n_samples),          # uniform
    rng.standard_t(df=3, size=n_samples),   # heavy tails
])
names = ["normal", "log-normal", "bimodal", "exponential", "uniform", "t(df=3)"]
fig = plot_feature_distributions(data, names, ncols=3, save_path="distributions.png")
plt.close(fig)
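The |skew| > 1 rule of thumb used in the function above can be checked numerically before committing to a transform. This is a minimal sketch in plain NumPy. `sample_skewness` computes the biased estimator m3 / m2^1.5, which is what `scipy.stats.skew` returns by default. The sample here is strictly positive, so `np.log` applies directly; for features that contain zeros, `np.log1p` is the usual variant.

```python
import numpy as np

def sample_skewness(x: np.ndarray) -> float:
    """Biased sample skewness m3 / m2**1.5 (scipy.stats.skew's default definition)."""
    d = x - x.mean()
    m2 = np.mean(d ** 2)
    m3 = np.mean(d ** 3)
    return float(m3 / m2 ** 1.5)

rng = np.random.default_rng(0)
income = rng.lognormal(mean=0.0, sigma=1.0, size=5_000)  # right-skewed, strictly positive

raw_skew = sample_skewness(income)
log_skew = sample_skewness(np.log(income))  # log of a log-normal sample is normal

print(f"raw skew = {raw_skew:.2f}, after log = {log_skew:.2f}")
if abs(raw_skew) > 1 > abs(log_skew):
    print("the log transform brings the feature under the |skew| > 1 threshold")
```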
8. Correlation Heatmap
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
def plot_correlation_heatmap(
    df: pd.DataFrame,
    method: str = "pearson",
    mask_upper: bool = True,
    save_path: str | None = None,
) -> plt.Figure:
    """
    Triangular correlation heatmap using a diverging colour map.
    Args:
        mask_upper: hides the redundant upper triangle; easier to scan.
        method: 'pearson' (linear), 'spearman' (rank), or 'kendall'.
    """
    corr = df.corr(method=method)
    n = len(corr)
    mask = np.zeros_like(corr, dtype=bool)
    if mask_upper:
        # k=1 excludes the diagonal; k=0 would mask the diagonal too
        mask[np.triu_indices_from(mask, k=1)] = True
    fig, ax = plt.subplots(figsize=(max(6, n * 0.75), max(5, n * 0.7)))
    sns.heatmap(
        corr,
        mask=mask,
        cmap="RdBu_r",  # red = positive, blue = negative, white = zero
        vmin=-1, vmax=1,
        center=0,
        annot=True, fmt=".2f", annot_kws={"size": 8},
        square=True,
        linewidths=0.5, linecolor="white",
        ax=ax,
        cbar_kws={"shrink": 0.8, "label": f"{method.capitalize()} r"},
    )
    ax.set_title(f"Feature Correlation ({method})", fontsize=12)
    ax.tick_params(axis="x", rotation=45)
    ax.tick_params(axis="y", rotation=0)
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path)
    return fig
# Synthetic DataFrame with known correlations
rng = np.random.default_rng(42)
n = 400
x1 = rng.normal(0, 1, n)
x2 = 0.92 * x1 + rng.normal(0, 0.15, n)   # near-perfect positive correlation (r ~ 0.99)
x3 = -0.65 * x1 + rng.normal(0, 0.5, n)   # strong negative correlation (r ~ -0.8)
x4 = rng.normal(0, 1, n)                  # independent
x5 = 0.40 * x4 + rng.normal(0, 0.85, n)   # weak-to-moderate positive correlation (r ~ 0.4)
df = pd.DataFrame({"income": x1, "spending": x2, "savings": x3,
                   "age": x4, "tenure": x5})
fig = plot_correlation_heatmap(df, method="pearson", save_path="correlation_heatmap.png")
plt.close(fig)
High pairwise correlations between features are a symptom of multicollinearity. Multicollinearity inflates the variance of coefficient estimates in linear models, making individual coefficients unstable and hard to interpret. The heatmap surfaces the pairwise cases in seconds.
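Pairwise correlations only catch collinearity between two features at a time. The variance inflation factor (VIF) also catches a feature that is a combination of several others. Below is a minimal NumPy sketch of the standard definition VIF_j = 1 / (1 - R²_j), where R²_j comes from regressing feature j on all the other features. The function name is illustrative, and the commonly quoted cut-offs of 5 or 10 are rules of thumb, not hard limits.

```python
import numpy as np

def variance_inflation_factors(X: np.ndarray) -> np.ndarray:
    """VIF for each column of X via least-squares regression on the other columns."""
    n, p = X.shape
    vifs = np.empty(p)
    for j in range(p):
        y = X[:, j]
        others = np.column_stack([np.ones(n), np.delete(X, j, axis=1)])  # intercept + rest
        coef, *_ = np.linalg.lstsq(others, y, rcond=None)
        resid = y - others @ coef
        r2 = 1.0 - resid @ resid / np.sum((y - y.mean()) ** 2)
        vifs[j] = 1.0 / (1.0 - r2)
    return vifs

# Features built like income, spending and age above (fewer random draws, so values differ)
rng = np.random.default_rng(42)
n = 400
x1 = rng.normal(0, 1, n)
x2 = 0.92 * x1 + rng.normal(0, 0.15, n)
x4 = rng.normal(0, 1, n)
X = np.column_stack([x1, x2, x4])
print(variance_inflation_factors(X).round(1))  # x1 and x2 large, x4 close to 1
```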
9. Embedding Visualisation: t-SNE and UMAP
High-dimensional embeddings -- from neural networks, sentence transformers, or clustering algorithms -- need 2D projection before you can reason about what the model has learned.
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
def plot_embeddings(
    embeddings: np.ndarray,
    labels: np.ndarray,
    class_names: list[str] | None = None,
    method: str = "tsne",
    perplexity: int = 30,
    title: str | None = None,
    save_path: str | None = None,
) -> plt.Figure:
    """
    Projects embeddings to 2D with t-SNE or UMAP and colours by class label.
    Notes on t-SNE:
    - Standardise embeddings before projecting; t-SNE is scale-sensitive.
    - Use init='pca' for more stable, reproducible layouts.
    - Neither distances BETWEEN clusters nor cluster sizes are meaningful;
      only local neighbourhoods, and whether clusters separate at all, are.
    - Try perplexity in [5, 50] (it must be smaller than the number of samples).
      Run multiple times to check stability.
    For UMAP (often faster, and usually better at preserving global structure):
    pip install umap-learn
    """
    emb_scaled = StandardScaler().fit_transform(embeddings)
    if method == "tsne":
        # Iteration count left at sklearn's default of 1000; the argument was
        # renamed from n_iter to max_iter in scikit-learn 1.5
        coords = TSNE(
            n_components=2, perplexity=perplexity,
            random_state=42, init="pca",
        ).fit_transform(emb_scaled)
    elif method == "umap":
        try:
            import umap  # type: ignore
        except ImportError as err:
            raise ImportError("Run: pip install umap-learn") from err
        coords = umap.UMAP(n_components=2, random_state=42).fit_transform(emb_scaled)
    else:
        raise ValueError(f"Unknown method: {method!r}. Use 'tsne' or 'umap'.")
    unique_classes = np.unique(labels)
    # plt.get_cmap still exists; plt.cm.get_cmap was removed in Matplotlib 3.9
    colors = plt.get_cmap("tab10").colors
    fig, ax = plt.subplots(figsize=(8, 6))
    for i, cls in enumerate(unique_classes):
        mask = labels == cls
        name = class_names[cls] if class_names else str(cls)
        ax.scatter(
            coords[mask, 0], coords[mask, 1],
            s=14, alpha=0.75, color=colors[i % len(colors)], label=name, linewidths=0,
        )
    ax.set_title(title or f"Embedding Projection ({method.upper()})", fontsize=12)
    ax.set_xlabel("Component 1")
    ax.set_ylabel("Component 2")
    ax.legend(markerscale=2, fontsize=9, loc="best")
    # No tick values -- they have no meaningful units in t-SNE/UMAP space
    ax.set_xticks([])
    ax.set_yticks([])
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path)
    return fig
# 4-class embeddings in 64-dimensional space
rng = np.random.default_rng(0)
n_per_class = 200
centres = rng.normal(0, 4, size=(4, 64))
X = np.vstack([
    centres[i] + rng.normal(0, 1.2, size=(n_per_class, 64))
    for i in range(4)
])
y = np.repeat(np.arange(4), n_per_class)
fig = plot_embeddings(
    X, y,
    class_names=["Positive", "Negative", "Neutral", "Mixed"],
    method="tsne", perplexity=40,
    save_path="embeddings_tsne.png",
)
plt.close(fig)
10. Publication-Quality Figure Saving
import matplotlib.pyplot as plt
from pathlib import Path
def save_figure(
    fig: plt.Figure,
    path: str | Path,
    formats: list[str] | None = None,
    dpi: int = 300,
) -> None:
    """
    Save a figure in one or more formats.
    If formats is omitted, the suffix of path is used (PNG if path has none).
    Recommended by use case:
        PDF -- LaTeX paper inclusion (vector, infinitely scalable)
        SVG -- Web / presentations (vector, editable in Inkscape)
        PNG -- Web embeds, Slack/Notion (raster, 150-300 dpi)
        EPS -- Legacy journal submission format
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    formats = formats or [path.suffix.lstrip(".") or "png"]
    for fmt in formats:
        out = path.with_suffix(f".{fmt}")
        fig.savefig(
            out,
            dpi=dpi,
            bbox_inches="tight",  # never clip axis labels or titles
            pad_inches=0.05,
            transparent=(fmt in {"svg", "pdf"}),  # transparent background for overlay
            metadata={"Creator": "EngineersOfAI", "Title": str(path.stem)},
        )
        print(f"Saved: {out} ({out.stat().st_size / 1024:.1f} KB)")
# Usage
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot([1, 2, 3, 4], [1, 4, 2, 3], marker="o")
ax.set_title("Validation Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
save_figure(fig, "output/val_loss", formats=["png", "pdf", "svg"], dpi=300)
plt.close(fig)
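You can confirm which outputs are vector and which are raster without touching the filesystem by rendering a figure into in-memory buffers. This small sketch assumes Matplotlib's Agg backend. It checks file signatures rather than trusting extensions: the PNG magic bytes, the `%PDF` header, and the `<svg` root element.

```python
import io
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

def render_bytes(fig: plt.Figure, fmt: str, dpi: int = 300) -> bytes:
    """Render fig in the given format into memory and return the raw bytes."""
    buf = io.BytesIO()
    fig.savefig(buf, format=fmt, dpi=dpi, bbox_inches="tight")
    return buf.getvalue()

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot([1, 2, 3, 4], [1, 4, 2, 3], marker="o")

outputs = {fmt: render_bytes(fig, fmt) for fmt in ("png", "pdf", "svg")}
plt.close(fig)

print(outputs["png"][:8] == b"\x89PNG\r\n\x1a\n")  # True: PNG signature (raster)
print(outputs["pdf"][:5] == b"%PDF-")             # True: PDF header (vector)
print(b"<svg" in outputs["svg"][:2000])           # True: SVG root element (vector)
```

The `dpi` argument sets the pixel density of the PNG. For PDF and SVG it only affects rasterised elements embedded in the vector output, such as images.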
11. Seaborn for Statistical Plots
Seaborn wraps Matplotlib for statistical visualisations. It integrates with DataFrames and handles grouping automatically. For estimate plots such as barplot, pointplot and lineplot, it draws bootstrapped 95% confidence intervals by default.
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
# Synthetic cross-validation results DataFrame: 25 folds per model
rng = np.random.default_rng(1)
df = pd.DataFrame({
    # np.repeat keeps each model's 25 rows contiguous, matching the
    # per-model accuracy blocks concatenated below
    "model": np.repeat(["LR", "RF", "XGB", "MLP"], 25),
    "fold": np.tile(np.arange(25), 4),
    "accuracy": np.concatenate([
        rng.normal(0.81, 0.022, 25),
        rng.normal(0.87, 0.015, 25),
        rng.normal(0.89, 0.012, 25),
        rng.normal(0.86, 0.025, 25),
    ]),
    # Random Train/Val label per row, so in this toy data both halves share one distribution
    "split": rng.choice(["Train", "Val"], size=100, p=[0.5, 0.5]),
})
val_df = df[df["split"] == "Val"]
fig, axes = plt.subplots(1, 2, figsize=(13, 5))
# Box plot with overlaid strip plot
sns.boxplot(
    data=val_df,
    x="model", y="accuracy",
    order=["LR", "RF", "MLP", "XGB"],
    color="#90CAF9", width=0.5, ax=axes[0],  # palette without hue is deprecated in seaborn 0.13
)
sns.stripplot(
    data=val_df,
    x="model", y="accuracy",
    order=["LR", "RF", "MLP", "XGB"],
    color="black", alpha=0.4, size=4, ax=axes[0],
)
axes[0].set_title("Model Accuracy Distribution (Val rows only)")
axes[0].set_ylabel("Accuracy")
# Split violin: one half per split
sns.violinplot(
    data=df,
    x="model", y="accuracy", hue="split",
    split=True,        # each side of the violin shows one hue level
    inner="quartile",  # quartile lines inside the violin body
    palette={"Train": "#90CAF9", "Val": "#EF9A9A"},
    ax=axes[1],
)
axes[1].set_title("Train vs Val Accuracy Distribution")
axes[1].set_ylabel("Accuracy")
axes[1].legend(title="Split", loc="lower right")
fig.tight_layout()
fig.savefig("seaborn_model_comparison.png", dpi=150)
plt.close(fig)
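Seaborn's default confidence intervals come from resampling. The percentile bootstrap for a mean can be sketched in a few lines of NumPy. This illustrates the technique, not seaborn's internal code, and `bootstrap_ci` is an illustrative name. The 1000 resamples and 95% level match seaborn's defaults.

```python
import numpy as np

def bootstrap_ci(values: np.ndarray, n_boot: int = 1000, ci: float = 95.0,
                 seed: int = 0) -> tuple[float, float]:
    """Percentile bootstrap confidence interval for the mean of `values`."""
    rng = np.random.default_rng(seed)
    # Draw n_boot resamples (with replacement) and record each resample's mean
    idx = rng.integers(0, len(values), size=(n_boot, len(values)))
    boot_means = values[idx].mean(axis=1)
    tail = (100.0 - ci) / 2.0
    lower, upper = np.percentile(boot_means, [tail, 100.0 - tail])
    return float(lower), float(upper)

rng = np.random.default_rng(1)
fold_accuracies = rng.normal(0.87, 0.015, 25)  # e.g. one model's 25 CV scores
lo, hi = bootstrap_ci(fold_accuracies)
print(f"mean={fold_accuracies.mean():.4f}  95% CI=({lo:.4f}, {hi:.4f})")
```

One caveat applies to cross-validation scores in particular. The folds share training data, so they are not independent, and an interval like this tends to understate the true uncertainty.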
12. Common Mistakes
Mistake 1: Mutating rcParams without restoring them
import matplotlib.pyplot as plt
# BAD: pollutes global state for every subsequent figure in the session
plt.rcParams["figure.figsize"] = (12, 8)
# GOOD: scope changes to a context manager
with plt.rc_context({"figure.figsize": (12, 8)}):
    fig, ax = plt.subplots()
    ax.plot([1, 2, 3], [3, 2, 1])
    fig.savefig("large_figure.png")
    plt.close(fig)
Mistake 2: Calling plt.show() inside library functions
import matplotlib.pyplot as plt
# BAD: blocks execution and prevents embedding in multi-panel figures
def plot_loss(train, val):
    plt.plot(train)
    plt.plot(val)
    plt.show()  # blocks in scripts; once shown, the figure is gone and cannot be composed
# GOOD: return the Figure; let the caller decide
def plot_loss(train, val) -> plt.Figure:
    fig, ax = plt.subplots()
    ax.plot(train, label="train")
    ax.plot(val, label="val")
    ax.legend()
    return fig
train_losses = [0.90, 0.60, 0.45, 0.40]  # placeholder histories for illustration
val_losses = [0.95, 0.70, 0.60, 0.62]
fig = plot_loss(train_losses, val_losses)
fig.savefig("loss.png")
plt.close(fig)
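A common refinement of the GOOD version is to accept an optional `ax` argument, as Seaborn's axes-level functions do. The function draws on the given Axes and creates a new Figure only when none is passed. Below is a minimal sketch; `plot_loss_on` is a hypothetical name.

```python
from __future__ import annotations

from typing import Sequence

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

def plot_loss_on(train: Sequence[float], val: Sequence[float],
                 ax: plt.Axes | None = None) -> plt.Axes:
    """Draw train/val loss on `ax`, or on a fresh Figure if no Axes is given."""
    if ax is None:
        _, ax = plt.subplots()
    ax.plot(train, label="train")
    ax.plot(val, label="val", linestyle="--")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    return ax

# Compose two runs side by side in one Figure
fig, (ax_a, ax_b) = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
plot_loss_on([1.0, 0.6, 0.4], [1.1, 0.8, 0.9], ax=ax_a)
plot_loss_on([1.0, 0.7, 0.5], [1.0, 0.75, 0.6], ax=ax_b)
ax_a.set_title("Run A")
ax_b.set_title("Run B")
plt.close(fig)
```

Returning the Axes rather than the Figure keeps the function usable inside any layout the caller builds.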
Mistake 3: Leaking Figure objects in loops
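# BAD: creates a new Figure every iteration without closing it
# (train_one_epoch, model, loader and history stand in for your own training code)
for epoch in range(1000):
    history["loss"].append(train_one_epoch(model, loader))
    fig, ax = plt.subplots()
    ax.plot(history["loss"])
    fig.savefig(f"loss_{epoch}.png")
    # Figure never closed -- pyplot keeps all 1000 Figure objects alive
    # (Matplotlib starts warning once more than 20 are open)
# GOOD: close the Figure after every save
for epoch in range(1000):
    history["loss"].append(train_one_epoch(model, loader))
    fig, ax = plt.subplots()
    ax.plot(history["loss"])
    fig.savefig(f"loss_{epoch}.png")
    plt.close(fig)  # release memory immediately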
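`plt.get_fignums()` lists the Figures pyplot is still tracking, which makes a cheap leak check for plotting loops. In this small sketch, a dummy loss history stands in for real training and figures are saved to an in-memory buffer instead of disk.

```python
import io
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

history = {"loss": []}
for epoch in range(25):
    history["loss"].append(1.0 / (epoch + 1))  # dummy loss values
    fig, ax = plt.subplots()
    ax.plot(history["loss"])
    fig.savefig(io.BytesIO(), format="png")  # in-memory stand-in for a real file
    plt.close(fig)  # without this, get_fignums() would report 25 open figures

print(len(plt.get_fignums()))  # 0 -> nothing leaked
```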
Mistake 4: Interpreting t-SNE cluster distances as meaningful
t-SNE optimises local neighbourhood structure only. Global distances between clusters differ across runs and change with perplexity. Never write "cluster A is closer to cluster B than to cluster C" based on a t-SNE plot. Use actual distance metrics in the original embedding space for that statement.
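For statements about which classes are close, you can compute distances in the original embedding space instead. The sketch below uses class centroids and Euclidean distance; cosine distance is a common alternative for text embeddings. The data reproduce the 4-class, 64-dimensional simulation from Section 9, and `centroid_distance_matrix` is an illustrative name.

```python
import numpy as np

def centroid_distance_matrix(X: np.ndarray, labels: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Euclidean distances between class centroids, computed in the original space."""
    classes = np.unique(labels)
    centroids = np.stack([X[labels == c].mean(axis=0) for c in classes])
    diff = centroids[:, None, :] - centroids[None, :, :]
    return classes, np.sqrt((diff ** 2).sum(axis=-1))

# Same generation code (and seed) as the Section 9 example
rng = np.random.default_rng(0)
n_per_class = 200
centres = rng.normal(0, 4, size=(4, 64))
X = np.vstack([centres[i] + rng.normal(0, 1.2, size=(n_per_class, 64)) for i in range(4)])
y = np.repeat(np.arange(4), n_per_class)

classes, D = centroid_distance_matrix(X, y)
print(D.round(1))  # row/column order follows `classes`
```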
Key Takeaways
- Use the object-oriented API (fig, ax = plt.subplots()) everywhere. The stateful plt.* API causes subtle bugs inside functions and loops.
- Set rcParams once per project using a shared style module; hard-coding sizes per plot creates inconsistency.
- Training curves are your first diagnostic: train loss falling + val loss rising = overfitting; both high = underfitting.
- Normalise confusion matrices by row to expose per-class recall; raw counts hide minority class failures.
- On imbalanced datasets, report PR-AUC alongside ROC-AUC. A model with high ROC-AUC can still have dismal precision on the minority class.
- Seaborn is for statistical plots with DataFrames; Matplotlib is for full control. They compose: every axes-level Seaborn function (boxplot, heatmap, violinplot, ...) draws on and returns a Matplotlib Axes.
- t-SNE and UMAP project local structure, not distance. Remove axis ticks and warn viewers that inter-cluster distances are not interpretable.
- Always call plt.close(fig) in scripts and loops. Unclosed Figure objects leak memory.
- Save publication figures as PDF for LaTeX and PNG at 300 dpi for web.
Practice Problems
Problem 1 -- Learning Rate Diagnostic
Train a small neural network with three learning rates: 0.001, 0.01, and 0.1. Plot all six training/validation loss curves on a single figure (two lines per LR, different colours per LR, dashed for validation). Add a legend, mark the best validation epoch per curve with a vertical annotation, and save as PDF.
Problem 2 -- Multi-Class ROC
For a 5-class classifier, compute one-vs-rest ROC curves for each class and plot them all on a single Axes with a legend showing class name and AUC. Add a shaded region under the macro-average curve. Hint: use sklearn.preprocessing.label_binarize to create one-vs-rest binary labels.
Problem 3 -- Distribution Shift Detection
Given two DataFrames df_train and df_prod, plot overlaid KDE curves for each feature comparing the training distribution to the production distribution (blue = train, red = prod). Annotate each subplot with the KL-divergence between the two distributions. Flag features where KL-divergence exceeds 0.1 with a red background. This is the foundation of a data drift monitoring dashboard.
Problem 4 -- Correlation-Filtered Heatmap
Build a correlation heatmap that only annotates cells where |correlation| > 0.5. For cells below that threshold, replace the annotation with an empty string and reduce the cell alpha to 0.15. This declutters large matrices and draws attention only to actionable correlations.
Problem 5 -- Animated Training Dashboard
Using matplotlib.animation.FuncAnimation, build a live-updating dashboard that replots the loss curve and a 2-class confusion matrix after every 5 epochs during training. Save the animation as a GIF. Hint: sns.heatmap returns an Axes, and the QuadMesh it draws is ax.collections[0]. Calling that QuadMesh's set_array() updates the cell colours in place without redrawing the full Axes every frame; the annotation Text objects in ax.texts need updating separately.
